# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""radio unit tests."""

from typing import Any
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
import pytest
from parameterized import parameterized

import streamlit as st
from streamlit.errors import StreamlitAPIException
from streamlit.proto.LabelVisibilityMessage_pb2 import LabelVisibilityMessage
from streamlit.testing.v1.app_test import AppTest
from streamlit.testing.v1.util import patch_config_options
from tests.delta_generator_test_case import DeltaGeneratorTestCase
from tests.streamlit.data_test_cases import (
    SHARED_TEST_CASES,
    CaseMetadata,
)
from tests.streamlit.elements.layout_test_utils import WidthConfigFields


class RadioTest(DeltaGeneratorTestCase):
    """Test that ``st.radio`` calls are marshalled into ``Radio`` protos correctly."""

    def test_just_label(self):
        """Test that it can be called with no value."""
        st.radio("the label", ("m", "f"))

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert (
            c.label_visibility.value
            == LabelVisibilityMessage.LabelVisibilityOptions.VISIBLE
        )
        assert c.default == 0
        assert not c.disabled
        assert c.HasField("default")
        assert c.captions == []

    def test_just_disabled(self):
        """Test that it can be called with disabled param."""
        st.radio("the label", ("m", "f"), disabled=True)

        c = self.get_delta_from_queue().new_element.radio
        assert c.disabled

    def test_none_value(self):
        """Test that it can be called with None as index value."""
        st.radio("the label", ("m", "f"), index=None)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        # Whether a proto property is null is not determined by its value
        # (unset scalars read as 0), but by the HasField check:
        assert c.default == 0
        assert not c.HasField("default")

    def test_horizontal(self):
        """Test that it can be called with horizontal param."""
        st.radio("the label", ("m", "f"), horizontal=True)

        c = self.get_delta_from_queue().new_element.radio
        assert c.horizontal

    def test_horizontal_default_value(self):
        """Test that it can be called without horizontal, which defaults to False."""
        st.radio("the label", ("m", "f"))

        c = self.get_delta_from_queue().new_element.radio
        assert not c.horizontal

    def test_valid_value(self):
        """Test that a valid integer index is marshalled as the default."""
        st.radio("the label", ("m", "f"), 1)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 1

    def test_noneType_option(self):
        """Test that None is accepted as an option and returned when selected."""
        current_value = st.radio("the label", (None, "selected"), 0)

        assert current_value is None

    @parameterized.expand(
        SHARED_TEST_CASES,
    )
    def test_option_types(self, name: str, input_data: Any, metadata: CaseMetadata):
        """Test that it supports different types of options."""
        st.radio("the label", input_data)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 0
        # Compared as sets of strings: some shared inputs have no stable order.
        assert {str(item) for item in c.options} == {
            str(item) for item in metadata.expected_sequence
        }

    def test_cast_options_to_string(self):
        """Test that it casts options to string."""
        arg_options = ["some str", 123, None, {}]
        proto_options = ["some str", "123", "None", "{}"]

        st.radio("the label", arg_options)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 0
        assert c.options == proto_options

    def test_format_function(self):
        """Test that it formats options."""
        arg_options = [{"name": "john", "height": 180}, {"name": "lisa", "height": 200}]
        proto_options = ["john", "lisa"]

        st.radio("the label", arg_options, format_func=lambda x: x["name"])

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 0
        assert c.options == proto_options

    @parameterized.expand([((),), ([],), (np.array([]),), (pd.Series(np.array([])),)])
    def test_no_options(self, options):
        """Test that it handles no options."""
        st.radio("the label", options)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert (
            c.label_visibility.value
            == LabelVisibilityMessage.LabelVisibilityOptions.VISIBLE
        )
        assert c.default == 0
        assert c.options == []

    def test_invalid_value(self):
        """Test that a non-int index raises StreamlitAPIException."""
        with pytest.raises(StreamlitAPIException):
            st.radio("the label", ("m", "f"), "1")

    def test_invalid_value_range(self):
        """Test that an index outside the options' range raises StreamlitAPIException."""
        with pytest.raises(StreamlitAPIException):
            st.radio("the label", ("m", "f"), 2)

    def test_outside_form(self):
        """Test that form id is marshalled correctly outside of a form."""

        st.radio("foo", ["bar", "baz"])

        proto = self.get_delta_from_queue().new_element.radio
        assert proto.form_id == ""

    @patch("streamlit.runtime.Runtime.exists", MagicMock(return_value=True))
    def test_inside_form(self):
        """Test that form id is marshalled correctly inside of a form."""

        with st.form("form"):
            st.radio("foo", ["bar", "baz"])

        # 2 elements will be created: form block, widget
        assert len(self.get_all_deltas_from_queue()) == 2

        form_proto = self.get_delta_from_queue(0).add_block
        radio_proto = self.get_delta_from_queue(1).new_element.radio
        assert radio_proto.form_id == form_proto.form.form_id

    def test_inside_column(self):
        """Test that it works correctly inside of a column."""
        col1, _col2 = st.columns(2)

        with col1:
            st.radio("foo", ["bar", "baz"])

        all_deltas = self.get_all_deltas_from_queue()

        # 4 elements will be created: 1 horizontal block, 2 columns, 1 widget
        assert len(all_deltas) == 4
        radio_proto = self.get_delta_from_queue().new_element.radio

        assert radio_proto.label == "foo"
        assert radio_proto.options == ["bar", "baz"]
        assert radio_proto.default == 0

    @parameterized.expand(
        [
            ("visible", LabelVisibilityMessage.LabelVisibilityOptions.VISIBLE),
            ("hidden", LabelVisibilityMessage.LabelVisibilityOptions.HIDDEN),
            ("collapsed", LabelVisibilityMessage.LabelVisibilityOptions.COLLAPSED),
        ]
    )
    def test_label_visibility(self, label_visibility_value, proto_value):
        """Test that it can be called with label_visibility param."""
        st.radio("the label", ("m", "f"), label_visibility=label_visibility_value)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 0
        assert c.label_visibility.value == proto_value

    def test_label_visibility_wrong_value(self):
        """Test that an unknown label_visibility raises a descriptive error."""
        with pytest.raises(StreamlitAPIException) as e:
            st.radio("the label", ("m", "f"), label_visibility="wrong_value")
        assert (
            str(e.value)
            == "Unsupported label_visibility option 'wrong_value'. Valid values are 'visible', 'hidden' or 'collapsed'."
        )

    def test_no_captions(self):
        """Test that it can be called with no captions."""
        st.radio("the label", ("option1", "option2", "option3"), captions=None)

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 0
        assert c.captions == []

    def test_some_captions(self):
        """Test that None and empty captions are both marshalled as empty strings."""
        st.radio(
            "the label",
            ("option1", "option2", "option3", "option4"),
            captions=("first caption", None, "", "last caption"),
        )

        c = self.get_delta_from_queue().new_element.radio
        assert c.label == "the label"
        assert c.default == 0
        assert c.captions == ["first caption", "", "", "last caption"]

    def test_shows_cached_widget_replay_warning(self):
        """Test that a warning is shown when this widget is used inside a cached function."""
        st.cache_data(lambda: st.radio("the label", ["option 1", "option 2"]))()

        # The widget itself is still created, so we need to go back one element more:
        el = self.get_delta_from_queue(-3).new_element.exception
        assert el.type == "CachedWidgetWarning"
        assert el.is_warning

    def test_radio_with_width(self):
        """Test st.radio with different width types."""
        test_cases = [
            (500, WidthConfigFields.PIXEL_WIDTH.value, "pixel_width", 500),
            ("stretch", WidthConfigFields.USE_STRETCH.value, "use_stretch", True),
            ("content", WidthConfigFields.USE_CONTENT.value, "use_content", True),
        ]

        for index, (
            width_value,
            expected_width_spec,
            field_name,
            field_value,
        ) in enumerate(test_cases):
            with self.subTest(width_value=width_value):
                # A distinct label per case keeps the generated widget IDs unique.
                st.radio(
                    f"test label {index}", ["option 1", "option 2"], width=width_value
                )

                el = self.get_delta_from_queue().new_element
                assert el.radio.label == f"test label {index}"

                assert el.width_config.WhichOneof("width_spec") == expected_width_spec
                assert getattr(el.width_config, field_name) == field_value

    def test_radio_with_invalid_width(self):
        """Test that st.radio rejects invalid width values with a clear message."""
        # Every invalid width produces the same validation message.
        expected_error_message = (
            "Width must be either a positive integer (pixels), 'stretch', or 'content'."
        )
        invalid_widths = ("invalid", -100, 0, 100.5)

        for width_value in invalid_widths:
            with self.subTest(width_value=width_value):
                with pytest.raises(StreamlitAPIException) as exc:
                    st.radio("test label", ["option 1", "option 2"], width=width_value)

                assert expected_error_message in str(exc.value)

    def test_radio_default_width(self):
        """Test that st.radio defaults to content width."""
        st.radio("test label", ["option 1", "option 2"])

        el = self.get_delta_from_queue().new_element
        assert el.radio.label == "test label"
        assert (
            el.width_config.WhichOneof("width_spec")
            == WidthConfigFields.USE_CONTENT.value
        )
        assert el.width_config.use_content is True

    def test_stable_id_with_key(self):
        """Test that the widget ID is stable when a stable key is provided."""
        with patch(
            "streamlit.elements.lib.utils._register_element_id",
            return_value=MagicMock(),
        ):
            # First render with certain params
            st.radio(
                label="Label 1",
                key="radio_key",
                index=0,
                help="Help 1",
                disabled=False,
                width="content",
                on_change=lambda: None,
                args=("arg1", "arg2"),
                kwargs={"kwarg1": "kwarg1"},
                label_visibility="visible",
                horizontal=False,
                captions=["c1", "c2"],
                # Whitelisted kwargs:
                options=["a", "b"],
                format_func=lambda x: x.capitalize(),
            )
            c1 = self.get_delta_from_queue().new_element.radio
            id1 = c1.id

            # Second render with different non-whitelisted params but same key
            st.radio(
                label="Label 2",
                key="radio_key",
                index=1,
                help="Help 2",
                disabled=True,
                width="stretch",
                on_change=lambda: None,
                args=("arg_1", "arg_2"),
                kwargs={"kwarg_1": "kwarg_1"},
                label_visibility="hidden",
                horizontal=True,
                captions=["c1x", "c2x"],
                # Whitelisted kwargs:
                options=["a", "b"],
                format_func=lambda x: x.capitalize(),
            )
            c2 = self.get_delta_from_queue().new_element.radio
            id2 = c2.id
            assert id1 == id2

    @parameterized.expand(
        [
            (
                "options",
                {"options": ["a", "b"], "format_func": str},
                {"options": ["a", "b", "c"], "format_func": str},
            ),
            (
                "format_func",
                {"options": ["a", "b"], "format_func": str},
                {"options": ["a", "b"], "format_func": str.upper},
            ),
        ]
    )
    def test_whitelisted_stable_key_kwargs(
        self, _name: str, first_kwargs: dict, second_kwargs: dict
    ) -> None:
        """Test that the widget ID changes when a whitelisted kwarg changes even when the key is provided."""
        with patch(
            "streamlit.elements.lib.utils._register_element_id",
            return_value=MagicMock(),
        ):
            st.radio(label="Label 1", key="radio_key2", **first_kwargs)
            c1 = self.get_delta_from_queue().new_element.radio
            id1 = c1.id

            st.radio(label="Label 2", key="radio_key2", **second_kwargs)
            c2 = self.get_delta_from_queue().new_element.radio
            id2 = c2.id
            assert id1 != id2


def test_radio_interaction():
    """Test selecting and then clearing a radio that starts with no selection."""

    def script():
        import streamlit as st

        st.radio("the label", ("m", "f"), index=None)

    at = AppTest.from_function(script).run()
    radio = at.radio[0]
    # index=None means nothing is selected initially.
    assert radio.value is None

    # Select option m
    at = radio.set_value("m").run()
    radio = at.radio[0]
    assert radio.value == "m"

    # Clear the value
    at = radio.set_value(None).run()
    radio = at.radio[0]
    assert radio.value is None


def test_radio_enum_coercion():
    """Test E2E Enum Coercion on a radio.

    The script redefines ``EnumA`` on every rerun. With enum coercion enabled
    the selected value must belong to the freshly defined class. With coercion
    off, it stays an instance of the previous run's class, so the identity
    check fails.
    """

    def script():
        from enum import Enum

        import streamlit as st

        class EnumA(Enum):
            A = 1
            B = 2
            C = 3

        selected = st.radio("my_enum", EnumA, index=0)
        st.text(id(selected.__class__))
        st.text(id(EnumA))
        st.text(selected in EnumA)

    at = AppTest.from_function(script).run()

    # Deliberately not named ``test_*``: this is a helper, not a test case.
    def assert_selected_enum_matches_current_class(app: AppTest) -> None:
        radio = app.radio[0]
        original_class = radio.value.__class__
        # Assert on the AppTest returned by run() instead of relying on the
        # closed-over object having been updated in place.
        rerun_app = radio.set_value(original_class.C).run()
        assert rerun_app.text[0].value == rerun_app.text[1].value, (
            "Enum Class ID not the same"
        )
        assert rerun_app.text[2].value == "True", "Not all enums found in class"

    with patch_config_options({"runner.enumCoercion": "nameOnly"}):
        assert_selected_enum_matches_current_class(at)
    with (
        patch_config_options({"runner.enumCoercion": "off"}),
        pytest.raises(AssertionError),
    ):
        # expect a failure with the config value off.
        assert_selected_enum_matches_current_class(at)


def test_None_session_state_value_retained():
    """A radio seeded with ``None`` via session state keeps ``None`` across reruns.

    The script initializes ``st.session_state["radio"]`` to ``None`` once. Clicking
    an unrelated button triggers a rerun, after which the radio value must still be
    ``None``.
    """

    def script():
        import streamlit as st

        # Seed the widget's keyed state only on the first run.
        if "radio" not in st.session_state:
            st.session_state["radio"] = None

        st.radio("radio", ["a", "b", "c"], key="radio")
        st.button("button")

    first_run = AppTest.from_function(script).run()
    after_click = first_run.button[0].click().run()

    assert after_click.radio[0].value is None
